This notebook analyzes filtered monoclonal antibody (mAb) escape data

In [1]:
# this cell is tagged as parameters for `papermill` parameterization
# Declare every input/output path here with a None default so each one is documented
# and defined for manual runs; papermill overrides these in the injected cell.
binding_data = None
HENV103_filter = None
HENV117_filter = None
HENV26_filter = None
HENV32_filter = None
m102_filter = None
nAH1_filter = None

altair_config = None
nipah_config = None
escape_bubble_plot = None
bubble_1_mut_plot = None
overlap_escape_plot = None  # output path for the shared-escape heatmap (was undeclared)
mab_line_escape_plot = None
aggregate_mab_and_binding = None
aggregate_mab_and_niv_polymorphism = None
binding_vs_escape = None

mab_plot_top = None
mab_plot_all = None
In [2]:
# Parameters
# Injected by papermill at run time. These values override the None defaults declared
# in the cell tagged `parameters`.
nipah_config = "nipah_config.yaml"
altair_config = "data/custom_analyses_data/theme.py"
HENV103_filter = "results/filtered_data/HENV103_escape_filtered.csv"
HENV117_filter = "results/filtered_data/HENV117_escape_filtered.csv"
HENV26_filter = "results/filtered_data/HENV26_escape_filtered.csv"
HENV32_filter = "results/filtered_data/HENV32_escape_filtered.csv"
m102_filter = "results/filtered_data/m102_escape_filtered.csv"
nAH1_filter = "results/filtered_data/nAH1_escape_filtered.csv"
binding_data = "results/filtered_data/E2_binding_filtered.csv"
escape_bubble_plot = "results/images/escape_bubble_plot.html"
bubble_1_mut_plot = "results/images/escape_bubble_1_mut_plot.html"
overlap_escape_plot = "results/images/overlap_escape_plot.html"
mab_line_escape_plot = "results/images/mab_line_escape_plot.html"
mab_plot_top = "results/images/mab_plot_top.html"
mab_plot_all = "results/images/mab_plot_all.html"
aggregate_mab_and_binding = "results/images/aggregate_mab_and_binding.html"
binding_vs_escape = "results/images/binding_vs_escape.html"
aggregate_mab_and_niv_polymorphism = (
    "results/images/aggregate_mab_and_niv_polymorphism.html"
)
In [3]:
# binding_data stays None unless papermill injected parameters, so use it to report how
# the notebook is being executed.
print('this is being run manually' if binding_data is None else 'papermill!')
papermill!
In [4]:
import math
import os
import re

import altair as alt

import numpy as np

import pandas as pd

import scipy.stats

import Bio.SeqIO
import yaml
import matplotlib
matplotlib.rcParams['svg.fonttype'] = 'none'

from Bio import PDB
import dmslogo
from dmslogo.colorschemes import CBPALETTE
from dmslogo.colorschemes import ValueToColorMap
In [5]:
# allow more rows for Altair
_ = alt.data_transformers.disable_max_rows()

# Project root that all relative paths in this notebook are resolved against.
# Set NIPAH_PROJECT_DIR to run from a different checkout.
PROJECT_DIR = os.environ.get(
    'NIPAH_PROJECT_DIR',
    '/fh/fast/bloom_j/computational_notebooks/blarsen/2023/Nipah_Malaysia_RBP_DMS/',
)

# os.getcwd() never ends in a trailing slash, so the previous raw string comparison
# could never match. Compare resolved paths instead.
if os.path.realpath(os.getcwd()) == os.path.realpath(PROJECT_DIR):
    print("Already in correct directory")
else:
    os.chdir(PROJECT_DIR)
    print("Setup in correct directory")
Setup in correct directory
In [6]:
#altair_config = 'data/custom_analyses_data/theme.py'
#nipah_config = 'nipah_config.yaml'

#binding_data = 'results/filtered_data/E2_binding_filtered.csv'

#HENV103_filter = 'results/filtered_data/HENV103_escape_filtered.csv'
#HENV117_filter = 'results/filtered_data/HENV117_escape_filtered.csv'
#HENV26_filter = 'results/filtered_data/HENV26_escape_filtered.csv'
#HENV32_filter = 'results/filtered_data/HENV32_escape_filtered.csv'
#m102_filter = 'results/filtered_data/m102_escape_filtered.csv'
#nAH1_filter = 'results/filtered_data/nAH1_escape_filtered.csv'
#
#escape_bubble_plot = 'results/images/escape_bubble_plot.html'
#bubble_1_mut_plot = 'results/images/escape_bubble_1_mut_plot.html'
#overlap_escape_plot = 'results/images/overlap_escape_plot.html'

#m102_heat = 'results/images/m102_heatmap.html'
#HENV26_heat = 'results/images/HENV26_heatmap.html'
#HENV32_heat = 'results/images/HENV32_heatmap.html'
#nAH1_heat = 'results/images/nAH1_heatmap.html'
#HENV117_heat = 'results/images/HENV117_heatmap.html'
#HENV103_heat = 'results/images/HENV103_heatmap.html'
In [7]:
# Optionally apply the shared Altair theme.
# NOTE: exec() runs the file's code in this notebook's namespace, so altair_config must
# point at a trusted, project-controlled file.
if altair_config:
    with open(altair_config) as theme_file:
        exec(theme_file.read())

# Analysis thresholds read later, e.g. min_func_effect_for_ab, func_times_seen_cutoff,
# min_escape_cutoff.
with open(nipah_config) as config_file:
    config = yaml.safe_load(config_file)

Make logo plots

Filtering parameters: identify mutants with low cell entry (used for masking later)

In [8]:
# Collect mutants with low EFNB3 cell-entry scores so they can be masked later. Keep only
# well-sampled measurements, and skip site 603 plus deletion ('-') and stop ('*') mutants.
func_scores_E3 = pd.read_csv('../Nipah_Malaysia_RBP_DMS/results/func_effects/averages/CHO_EFNB3_low_func_effects.csv')
low_entry_mask = (
    func_scores_E3['effect'].lt(config['min_func_effect_for_ab'])
    & func_scores_E3['times_seen'].gt(config['func_times_seen_cutoff'])
    & func_scores_E3['site'].ne(603)
    & ~func_scores_E3['mutant'].isin(['-', '*'])
)
func_scores_E3_low_effect = func_scores_E3[low_entry_mask]
display(func_scores_E3_low_effect)
site wildtype mutant effect effect_std times_seen n_selections
13 71 Q P -3.506 0.00000 6.714 7
23 72 N C -2.485 0.62820 5.429 7
34 72 N P -3.545 0.00000 6.714 7
39 72 N V -3.084 0.06917 6.000 7
55 73 Y P -3.412 0.90800 3.833 6
... ... ... ... ... ... ... ...
10753 597 I S -2.796 0.08909 3.000 7
10754 597 I T -3.526 0.00000 6.857 7
10759 598 P C -2.300 0.60270 4.571 7
10776 598 P W -3.177 0.00000 8.143 7
10777 598 P Y -2.057 0.20350 3.286 7

2711 rows × 7 columns

Read in and combine the filtered antibody escape files

In [9]:
# Read each antibody's filtered escape table. The per-antibody names stay defined for later cells.
HENV103, HENV117, HENV26, HENV32, m102, nAH1 = (
    pd.read_csv(escape_path)
    for escape_path in (
        HENV103_filter, HENV117_filter, HENV26_filter,
        HENV32_filter, m102_filter, nAH1_filter,
    )
)

# Stack all antibodies into one long table, keeping only the columns used downstream
# (the original row index of each file is preserved).
combined_df = pd.concat([HENV103, HENV117, HENV26, HENV32, m102, nAH1])[
    ['site', 'wildtype', 'mutant', 'mutation', 'effect', 'escape_median',
     'escape_std', 'times_seen_ab', 'show_site', 'ab']
]
display(combined_df)

# Top escape mutations only: sites flagged show_site whose escape clears the config cutoff.
filtered_df = combined_df.loc[
    (combined_df['show_site'] == True)
    & (combined_df['escape_median'] >= config['min_escape_cutoff'])
]
display(filtered_df)
site wildtype mutant mutation effect escape_median escape_std times_seen_ab show_site ab
0 71 Q D Q71D -0.7886 -0.002836 0.05669 3.000 False HENV-103
1 71 Q E Q71E 0.4129 -0.044930 0.12250 3.333 False HENV-103
2 71 Q F Q71F -0.3917 0.026290 0.01273 2.333 False HENV-103
3 71 Q G Q71G -0.3752 0.012700 0.01622 3.000 False HENV-103
4 71 Q H Q71H -0.1068 0.015420 0.22880 2.667 False HENV-103
... ... ... ... ... ... ... ... ... ... ...
6911 602 T R T602R 0.5666 -0.023040 0.06005 6.667 False nAH1.3
6912 602 T S T602S 0.2874 0.160600 0.24200 3.667 False nAH1.3
6913 602 T V T602V 0.4577 0.134100 0.11030 5.000 False nAH1.3
6914 602 T W T602W 0.5192 0.060960 0.13720 5.667 False nAH1.3
6915 602 T Y T602Y 0.5354 0.121700 0.11090 5.000 False nAH1.3

41447 rows × 10 columns

site wildtype mutant mutation effect escape_median escape_std times_seen_ab show_site ab
861 151 P A P151A -0.4827 0.3086 0.2750 4.000 True HENV-103
863 151 P G P151G -0.3962 0.4608 0.2998 4.333 True HENV-103
864 151 P I P151I -1.7310 0.3507 0.3999 2.667 True HENV-103
866 151 P L P151L -1.5310 0.3130 0.1990 9.333 True HENV-103
902 154 K I K154I -0.9768 0.5460 0.3375 5.333 True HENV-103
... ... ... ... ... ... ... ... ... ... ...
6891 601 C P C601P -1.0200 0.4141 0.2796 2.333 True nAH1.3
6892 601 C R C601R -1.3340 0.4948 0.3901 11.330 True nAH1.3
6893 601 C S C601S 0.4662 0.6669 0.2217 3.000 True nAH1.3
6894 601 C T C601T -0.3032 0.6831 0.2290 4.333 True nAH1.3
6896 601 C Y C601Y -1.1710 0.5645 0.2494 14.330 True nAH1.3

633 rows × 10 columns

In [10]:
def identify_escape_sites(df, ab):
    """Return the distinct sites for antibody ``ab`` in ``df``, in order of first appearance."""
    ab_sites = df.loc[df['ab'] == ab, 'site']
    return list(ab_sites.unique())

# Antibodies whose top escape sites are collected below.
# NOTE(review): `abs` shadows Python's built-in abs() for the rest of the notebook.
# Consider renaming it (e.g. `ab_names`) once no later cell depends on this name.
abs = ['HENV-26', 'HENV-103', 'HENV-32', 'HENV-117', 'm102.4', 'nAH1.3']
sites_dict = {}  # Create an empty dictionary to store the results

# Map each antibody to its sites in filtered_df (show_site rows above min_escape_cutoff).
for ab in abs:
    sites_dict[ab] = identify_escape_sites(filtered_df, ab)

display(sites_dict) #need site dict for later
{'HENV-26': [165,
  166,
  167,
  170,
  171,
  172,
  176,
  204,
  233,
  257,
  312,
  403,
  490,
  491,
  492,
  494,
  497,
  501,
  529,
  530,
  531,
  580,
  586,
  589],
 'HENV-103': [151,
  154,
  156,
  160,
  174,
  176,
  205,
  242,
  258,
  259,
  260,
  261,
  264,
  265,
  268,
  273,
  274,
  275,
  277],
 'HENV-32': [151,
  153,
  154,
  160,
  176,
  199,
  200,
  201,
  205,
  207,
  265,
  268,
  274,
  275,
  277,
  402,
  509,
  534,
  549,
  552,
  555,
  556,
  593,
  596],
 'HENV-117': [166,
  170,
  171,
  172,
  204,
  208,
  217,
  233,
  257,
  270,
  351,
  555,
  578,
  580,
  582,
  583,
  586,
  587,
  588,
  589,
  592],
 'm102.4': [165,
  171,
  172,
  209,
  218,
  239,
  243,
  270,
  305,
  507,
  509,
  529,
  532,
  542,
  555,
  558,
  559,
  577,
  580,
  582,
  586,
  587,
  588,
  589],
 'nAH1.3': [184,
  185,
  186,
  187,
  188,
  189,
  190,
  447,
  448,
  449,
  450,
  451,
  468,
  513,
  516,
  517,
  518,
  519,
  520,
  601]}

Plot a bubble chart of mAb escape for individual mutants against their functional (cell entry) score for E2 or E3

In [11]:
# Left-to-right order of antibodies on the x-axis of the escape charts below.
order_ab = ['m102.4','HENV-26','HENV-117','HENV-103','HENV-32','nAH1.3']

def generate_chart(df):
    """Bubble chart of top escape mutations, one jittered column per antibody.

    Each point is one mutation. x is the antibody (ordered by the module-level
    ``order_ab`` list), y is the mutation's cell-entry ``effect``, and point size is
    ``escape_median``. Hovering a point highlights every point at the same site.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs columns site, wildtype, mutant, ab, effect, escape_median, escape_std.

    Returns
    -------
    alt.Chart
    """
    # Hover selection keyed on site. value=1 seeds it with site 1 (presumably absent from
    # the data), so nothing is highlighted until the user hovers.
    variant_selector = alt.selection_point(
        on="mouseover",
        empty=False,
        fields=["site"],
        value=1
    )
    chart = alt.Chart(df,title=alt.Title('Top Antibody Escape Mutations',subtitle='Hover over points to see escape at same site')).mark_point(filled=True, stroke='black').encode(
        x=alt.X('ab:O',sort=order_ab, title='Antibody', axis=alt.Axis(labelAngle=-90,grid=False)),
        y=alt.Y('effect:Q', title='Cell Entry of Top Escape', axis=alt.Axis(grid=True, tickCount=4,values=[0.5, 0, -0.5, -1, -1.5, -2])),
        # Size encodes escape_median, so the legend must say "Median", not "Mean".
        size=alt.Size('escape_median', legend=alt.Legend(title='Median Escape By Mutation')),
        # Random horizontal offset spreads overlapping points within each antibody column.
        xOffset='random:Q',
        tooltip=['site','wildtype','mutant','ab', 'effect','escape_median','escape_std'],
        color=alt.Color('ab').legend(None),
        opacity=alt.condition(variant_selector, alt.value(1), alt.value(0.4)),
        strokeWidth=alt.condition(variant_selector,alt.value(2),alt.value(0))
    ).transform_calculate(
        # Box-Muller-style normal jitter. The factor -1 (not -2) gives variance 1/2.
        random="sqrt(-1*log(random()))*cos(2*PI*random())"
    ).properties(
        height=300,
        width='container'
    ).add_params(variant_selector)

    return chart


# Render the top-site bubble chart and save it to the escape_bubble_plot HTML path.
escape_bubble = generate_chart(filtered_df)
escape_bubble.display()
escape_bubble.save(escape_bubble_plot)

Now summarize escape by the number of nucleotide mutations between the wildtype and mutant codons

In [12]:
# Load in wt nucleotide sequence (which is different than the 'wt' sequence from Library as it was codon optimized)
# niv_m_wt is the plain nucleotide string; codon i occupies niv_m_wt[3*(i-1):3*i].
niv_m_wt = str(Bio.SeqIO.read('data/custom_analyses_data/alignments/wild_type_seq.fasta', 'fasta').seq)

# Standard genetic code: codon -> one-letter amino acid ('*' = stop).
# Keep this literal order: find_closest_codon keeps the first codon among equally close ones.
codon_table = {
    "ATA":"I", "ATC":"I", "ATT":"I", "ATG":"M",
    "ACA":"T", "ACC":"T", "ACG":"T", "ACT":"T",
    "AAC":"N", "AAT":"N", "AAA":"K", "AAG":"K",
    "AGC":"S", "AGT":"S", "AGA":"R", "AGG":"R",
    "CTA":"L", "CTC":"L", "CTG":"L", "CTT":"L",
    "CCA":"P", "CCC":"P", "CCG":"P", "CCT":"P",
    "CAC":"H", "CAT":"H", "CAA":"Q", "CAG":"Q",
    "CGA":"R", "CGC":"R", "CGG":"R", "CGT":"R",
    "GTA":"V", "GTC":"V", "GTG":"V", "GTT":"V",
    "GCA":"A", "GCC":"A", "GCG":"A", "GCT":"A",
    "GAC":"D", "GAT":"D", "GAA":"E", "GAG":"E",
    "GGA":"G", "GGC":"G", "GGG":"G", "GGT":"G",
    "TCA":"S", "TCC":"S", "TCG":"S", "TCT":"S",
    "TTC":"F", "TTT":"F", "TTA":"L", "TTG":"L",
    "TAC":"Y", "TAT":"Y", "TAA":"*", "TAG":"*",
    "TGC":"C", "TGT":"C", "TGA":"*", "TGG":"W"
}

def find_closest_codon(wt_codon, mutant_aa, table=None):
    """Find the codon for ``mutant_aa`` needing the fewest nucleotide changes from ``wt_codon``.

    Parameters
    ----------
    wt_codon : str
        Wildtype codon (normally 3 nucleotides). Positions are compared pairwise.
    mutant_aa : str
        One-letter target amino acid ('*' for stop).
    table : dict, optional
        Codon -> amino-acid mapping. Defaults to the module-level ``codon_table``.

    Returns
    -------
    tuple of (str or None, int)
        The closest codon and its number of differing positions. Ties keep the first
        matching codon in ``table`` order. If no codon encodes ``mutant_aa`` (e.g. '-'
        or NaN), returns ``(None, 3)``.
    """
    if table is None:
        table = codon_table
    closest_codon = None
    min_mutations = 3  # reported when no codon encodes mutant_aa
    for m_codon, aa in table.items():
        if aa != mutant_aa:
            continue
        mutations = sum(c1 != c2 for c1, c2 in zip(wt_codon, m_codon))
        # Always accept the first candidate. With only a `< 3` test, a mutant reachable
        # solely by 3 changes came back as (None, 3) with no codon.
        if closest_codon is None or mutations < min_mutations:
            min_mutations = mutations
            closest_codon = m_codon
    return closest_codon, min_mutations

def extract_codon(site):
    """Return the wildtype codon for 1-based amino-acid ``site`` from ``niv_m_wt``.

    Returns a shorter (possibly empty) string when ``site`` runs past the sequence end.
    """
    start = 3 * (site - 1)
    return niv_m_wt[start:start + 3]

def extract_codon_niv_b(site):
    """Return the codon for 1-based amino-acid ``site``.

    NOTE(review): despite the name, this slices ``niv_m_wt`` (the NiV-M sequence), so it
    is identical to extract_codon. Point it at a NiV-B sequence if that was the intent.
    """
    start = 3 * (site - 1)
    return niv_m_wt[start:start + 3]

def apply_codon_to_df(df,extract_func):
    """Return a copy of ``df`` annotated with codon-level mutation distances.

    Adds three columns:
    - ``wt_codon``: ``extract_func(site)`` for each row.
    - ``closest_mutant_codon`` / ``min_mutations``: the result of
      ``find_closest_codon(wt_codon, mutant)``.

    Works on a copy, so a slice of another frame is never written through (avoids
    SettingWithCopyWarning). find_closest_codon is called once per row rather than twice.
    """
    df = df.copy()
    df['wt_codon'] = df['site'].apply(extract_func)
    closest = [
        find_closest_codon(wt_codon, mutant)
        for wt_codon, mutant in zip(df['wt_codon'], df['mutant'])
    ]
    df['closest_mutant_codon'] = [codon for codon, _ in closest]
    df['min_mutations'] = [n_mut for _, n_mut in closest]
    return df

# Annotate every row with its wildtype codon, closest mutant codon, and the number of
# nucleotide changes between them (min_mutations). The charts below filter on min_mutations.
combined_df = apply_codon_to_df(combined_df,extract_codon)
filtered_df = apply_codon_to_df(filtered_df,extract_codon)
In [13]:
def generate_chart_all(df):
    """Interactive bubble chart of all escape mutations per antibody.

    Uses the same layout as generate_chart and adds two bound controls:
    - a radio button that keeps only rows whose ``min_mutations`` equals the chosen 1-3;
    - a slider (0.2-1.6) that keeps only rows with ``escape_median`` at or above it.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs columns site, wildtype, mutant, ab, effect, escape_median, escape_std,
        min_mutations.

    Returns
    -------
    alt.Chart
    """
    variant_selector = alt.selection_point(
        on="mouseover",
        empty=False,
        fields=["site"],
        value=1
    )
    radio = alt.binding_radio(options=[1, 2, 3], labels=['1','2','3'], name='Min Mutations:')
    mutation_selector = alt.param(name="MutationSelector", value=1, bind=radio)

    slider = alt.binding_range(min=0.2, max=1.6, step=0.1, name="median_escape")
    selector = alt.param(name="SelectorName", value=0.2, bind=slider)

    chart = alt.Chart(df,title=alt.Title('Antibody Escape Mutations',subtitle='Hover over points to see escape at same site')).mark_point(filled=True, stroke='black').encode(
        x=alt.X('ab:O',sort=order_ab, title='Antibody', axis=alt.Axis(labelAngle=-90,grid=False)),
        y=alt.Y('effect:Q', title='Cell Entry of Top Escape', axis=alt.Axis(grid=True, tickCount=4,values=[0.5, 0, -0.5, -1, -1.5, -2])),
        # Size encodes escape_median, so the legend must say "Median", not "Mean".
        size=alt.Size('escape_median', legend=alt.Legend(title='Median Escape By Mutation')),
        xOffset='random:Q',
        tooltip=['site','wildtype','mutant','ab', 'effect','escape_median','escape_std'],
        color=alt.Color('ab').legend(None),
        opacity=alt.condition(variant_selector, alt.value(1), alt.value(0.4)),
        strokeWidth=alt.condition(variant_selector,alt.value(2),alt.value(0))
    ).transform_calculate(
        # Normal-ish horizontal jitter (Box-Muller with factor -1).
        random="sqrt(-1*log(random()))*cos(2*PI*random())"
    ).properties(
        height=300,
        width='container'
    ).add_params(variant_selector,mutation_selector,selector).transform_filter(
        # Use >= so rows exactly at the slider value (including the 0.2 minimum) stay visible.
        (alt.datum.min_mutations == mutation_selector) &
        (alt.datum.escape_median >= selector)
    )

    return chart


# Pre-filter to escape_median >= 0.2, which matches the escape slider's minimum in generate_chart_all.
# NOTE(review): unlike the other charts, this one is displayed but not saved to disk.
all_escape = generate_chart_all(combined_df.query('escape_median >= 0.2'))
all_escape.display()
In [14]:
def plot_escape_and_mutations_away(df):
    """Bubble chart of top escape mutations, filterable by nucleotide distance.

    A radio control (1-3) keeps only rows whose ``min_mutations`` equals the chosen
    value. Hovering highlights every point at the same site.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs columns ab, effect, escape_median, site, mutant, min_mutations.

    Returns
    -------
    alt.Chart
    """
    hover = alt.selection_point(on="mouseover", empty=False, fields=["site"], value=1)
    mutation_selector = alt.param(
        name="MutationSelector",
        value=1,
        bind=alt.binding_radio(options=[1, 2, 3], name='Min Mutations:'),
    )

    title = alt.Title('Top Antibody Escape Mutations', subtitle='By # of nucleotide mutations away')
    effect_axis = alt.Axis(grid=True, tickCount=4, values=[0.5, 0, -0.5, -1, -1.5, -2])
    encodings = dict(
        x=alt.X('ab:O', sort=order_ab, title=None, axis=alt.Axis(labelAngle=-90, grid=False)),
        y=alt.Y('effect:Q', title='Cell Entry of Escape Mutants', axis=effect_axis),
        size=alt.Size('escape_median', legend=alt.Legend(title='Escape of Mutant')),
        xOffset='random:Q',
        tooltip=['ab', 'effect', 'escape_median', 'site', 'mutant'],
        opacity=alt.condition(hover, alt.value(1), alt.value(0.4)),
        strokeWidth=alt.condition(hover, alt.value(2), alt.value(0)),
        color=alt.Color('ab:N').legend(None),
    )

    points = alt.Chart(df, title=title).mark_point(filled=True, stroke='black').encode(**encodings)
    # Standard Box-Muller normal jitter for the horizontal offset.
    jittered = points.transform_calculate(random="sqrt(-2*log(random()))*cos(2*PI*random())")
    sized = jittered.properties(height=300, width='container')
    return sized.add_params(hover, mutation_selector).transform_filter(
        (alt.datum.min_mutations == mutation_selector)
    )

# Render the nucleotide-distance bubble chart for the top sites and save it as HTML.
bubble_plot_1_mut_away = plot_escape_and_mutations_away(filtered_df)
bubble_plot_1_mut_away.display()
bubble_plot_1_mut_away.save(bubble_1_mut_plot)
In [15]:
def find_overlapping_escape(df):
    """Heatmap of escape mutations shared by two or more antibodies.

    Keeps only (site, mutant) pairs that appear for at least two distinct antibodies,
    then plots mutation (x, ordered by site number) against antibody (y), colored by
    ``escape_median``. A slider filters on ``effect`` and a radio control filters on
    ``min_mutations``.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs columns site, wildtype, mutant, mutation, ab, effect, escape_median,
        min_mutations.

    Returns
    -------
    alt.Chart
    """
    # NOTE(review): the slider starts at -4 so every row is shown initially. That value
    # may lie below the slider's own minimum (config['min_func_effect_for_ab']).
    slider = alt.binding_range(min=config['min_func_effect_for_ab'], max=0, step=0.25, name="effect")
    selector = alt.param(name="SelectorName", value=-4, bind=slider)

    radio = alt.binding_radio(options=[1, 2, 3], name='Min Mutations:')
    mutation_selector = alt.param(name="MutationSelector", value=1, bind=radio)

    # Count distinct antibodies per (site, mutant) and keep pairs shared by at least two.
    ab_counts = df.groupby(['site', 'mutant'])['ab'].nunique().reset_index()
    shared = ab_counts[ab_counts['ab'] >= 2]

    # Merge back to recover the full per-antibody rows for the shared mutations.
    df_result = pd.merge(df, shared[['site', 'mutant']], on=['site', 'mutant'])
    # Raw string: '\d' in a plain literal is an invalid escape sequence
    # (SyntaxWarning on Python 3.12+).
    df_result['mutation_number'] = df_result['mutation'].str.extract(r'(\d+)', expand=False).astype(int)
    base = (
        alt.Chart(df_result,title=alt.Title('Shared antibody escape mutations')).mark_rect().encode(
            x=alt.X('mutation:O', title='Site',sort=alt.EncodingSortField(field='mutation_number'), axis=alt.Axis(labelAngle=-90,grid=False)),
            # The y channel shows antibodies, so title it accordingly (was 'Mutant').
            y=alt.Y('ab:O', title='Antibody',sort=order_ab,axis=alt.Axis(grid=False)),
            color='escape_median',
            tooltip=['site','wildtype','mutant','escape_median','min_mutations'],
        ).properties(
            width=alt.Step(30),
            height=alt.Step(20)
        )
    ).add_params(selector,mutation_selector).transform_filter(
            (alt.datum.effect >= selector) & (alt.datum.min_mutations == mutation_selector)
    )
    return base
# Render the shared-escape heatmap from the top-site subset and save it to overlap_escape_plot.
overlap_escape = find_overlapping_escape(filtered_df)
overlap_escape.display()
overlap_escape.save(overlap_escape_plot)

Line plots of summed escape by site

In [16]:
def plot_line_escape(df):
    """Stacked line plots of summed escape per site, one panel per antibody.

    Parameters
    ----------
    df : pandas.DataFrame
        Escape data with columns site, ab, escape_median.

    Returns
    -------
    alt.VConcatChart
        One 1000x100 panel per antibody (in ``ab_list`` order) with independent y
        scales and a shared x scale. Only the bottom panel shows x-axis labels.
    """
    variant_selector = alt.selection_point(
        on="mouseover",
        empty=False,
        fields=["site"],
        value=0
    )
    # Sum escape_median over all mutations at each site, separately for each antibody.
    summed = df.groupby(['site','ab'])['escape_median'].sum().reset_index()
    ab_list = ['m102.4', 'HENV-26', 'HENV-117', 'HENV-103', 'HENV-32', 'nAH1.3']
    # Line color per antibody, grouped by epitope (see the chart subtitle).
    epitope_colors = {
        'm102.4': "#1f4e79",
        'HENV-26': "#1f4e79",
        'HENV-117': "#1f4e79",
        'HENV-103': "#ff7f0e",
        'HENV-32': "#ff7f0e",
        'nAH1.3': "#2ca02c",
    }
    panels = []
    for idx, ab in enumerate(ab_list):
        tmp_df = summed[summed['ab'] == ab]
        # Grey fallback: previously an unmapped antibody raised NameError or reused the
        # previous antibody's color.
        color = epitope_colors.get(ab, "#7f7f7f")

        # Only the last (bottom) panel gets x-axis labels and a title.
        is_last_plot = idx == len(ab_list) - 1
        x_axis = alt.Axis(values=[100, 200, 300, 400, 500, 600], tickCount=6, labelAngle=-90, grid=True,
                          labelExpr="datum.value % 100 === 0 ? datum.value : ''",
                          title="Site" if is_last_plot else None,
                          labels=is_last_plot)
        base = (
            alt.Chart(tmp_df).mark_line(size=1, color=color).encode(
                x=alt.X('site:O', axis=x_axis),
                y=alt.Y('escape_median', title=f'{ab}', axis=alt.Axis(grid=True)),
            ).properties(
                width=1000,
                height=100
            )
        )
        # Points on top of the line; hovered sites grow and turn black.
        point = base.mark_point(color='black',size=10,filled=True).encode(
            x=alt.X('site:O', axis=x_axis),
            y=alt.Y('escape_median', title=f'{ab}', axis=alt.Axis(grid=True,)),
            size=alt.condition(variant_selector, alt.value(100),alt.value(15)),
            color=alt.condition(variant_selector, alt.value('black'), alt.value(color)),
            tooltip=['site','escape_median'],
        ).properties(
            width=1000,
            height=100,
        ).add_params(variant_selector)
        panels.append(base + point)

    # vconcat spacing=1 packs the panels tightly. Each antibody keeps its own y scale.
    combined_chart = alt.vconcat(*panels, spacing=1).resolve_scale(y='independent', x='shared', color='independent').properties(title=alt.Title('Summed Antibody Escape by Site',subtitle='Colored by epitope'))

    return combined_chart

# Line plots use combined_df (all mutations), not only the top-site subset.
tmp_line = plot_line_escape(combined_df)
tmp_line.display()
tmp_line.save(mab_line_escape_plot)

Now calculate the minimum atomic distance between each RBP site and the closest residue in the antibody heavy and light chains

In [17]:
def calculate_min_distances(pdb_path, source_chain_id, target_chain_ids, name):
    """Per-residue minimum atom-atom distance from a source chain to target chains.

    For each residue of ``source_chain_id`` in model 0 of ``pdb_path``, find the closest
    residue (by any atom pair) across ``target_chain_ids``. Also count how many target
    residues have at least one atom pair closer than 4.

    Parameters
    ----------
    pdb_path : str
        PDB file to parse.
    source_chain_id : str
        Chain that gets one output row per residue.
    target_chain_ids : list of str
        Chains searched for the closest residue.
    name : str
        Label stored in the ``ab`` column.

    Returns
    -------
    pandas.DataFrame
        Columns: wildtype, site, chain, residue, residue_name, distance,
        residues_within_4, ab. If no target residue is found, chain/residue/residue_name
        are None and distance is inf (previously this raised AttributeError).
    """
    source_skip = {"HOH", "WAT", "IPA", "NAG"}
    # NOTE(review): NAG is skipped on the source chain but not on target chains; confirm
    # whether antibody glycans should count as contacts.
    target_skip = {"HOH", "WAT", "IPA"}

    parser = PDB.PDBParser(QUIET=True)
    structure = parser.get_structure('structure_id', pdb_path)
    model = structure[0]
    source_chain = model[source_chain_id]
    target_chains = [model[chain_id] for chain_id in target_chain_ids]

    records = []
    for residueA in source_chain:
        if residueA.resname in source_skip:
            continue

        min_distance = float('inf')
        closest_residueB = None
        closest_chain_id = None
        residues_within_4 = 0

        for target_chain in target_chains:
            for residueB in target_chain:
                if residueB.resname in target_skip:
                    continue

                # Scan all atom pairs: track the global minimum and whether this
                # residue has any contact closer than 4.
                is_within_4 = False
                for atomA in residueA:
                    for atomB in residueB:
                        distance = atomA - atomB
                        if distance < min_distance:
                            min_distance = distance
                            closest_residueB = residueB
                            closest_chain_id = target_chain.get_id()
                        if distance < 4:
                            is_within_4 = True
                if is_within_4:
                    residues_within_4 += 1

        found = closest_residueB is not None
        records.append({
            'wildtype': residueA.resname,
            'site': residueA.id[1],
            'chain': closest_chain_id,
            'residue': closest_residueB.id[1] if found else None,
            'residue_name': closest_residueB.resname if found else None,
            'distance': min_distance,
            'residues_within_4': residues_within_4,
            'ab': name
        })

    return pd.DataFrame(records)

def check_file(input_path,source_chain,target_chain,name,output_path):
    """Load cached atomic distances from ``output_path``, computing them if absent.

    Parameters
    ----------
    input_path : str
        PDB file passed to ``calculate_min_distances`` when no cache exists.
    source_chain : str
        Chain that gets one row per residue.
    target_chain : list of str
        Chains searched for the closest residue.
    name : str
        Antibody label stored in the ``ab`` column (also used in the progress message).
    output_path : str
        CSV cache location. Written after a fresh calculation.

    Returns
    -------
    pandas.DataFrame
    """
    if os.path.exists(output_path):
        print("File already exists,loading from disk")
        return pd.read_csv(output_path)

    print(f'File {name} does not exist, running calculation')
    output_df = calculate_min_distances(input_path,source_chain,target_chain,name)
    print(f'done calculating for {output_path}')
    # Create the cache directory (e.g. results/distances/) on a fresh checkout; otherwise
    # to_csv raises after the expensive calculation has already run.
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    output_df.to_csv(output_path,index=False)
    return output_df



# Per-antibody structure inputs: PDB file, source chain (one output row per residue),
# the antibody chains searched for contacts, and the CSV cache path.
pdb_path_26 = 'data/custom_analyses_data/crystal_structures/6vy5.pdb'
source_chain_26 = 'A'
target_chains_26 = ['H', 'L']
output_path_26 = 'results/distances/df_HENV26_atomic_distances.csv'

pdb_path_32 = 'data/custom_analyses_data/crystal_structures/6vy4.pdb'
source_chain_32 = 'A'
target_chains_32 = ['H', 'L']
output_path_32 = 'results/distances/df_HENV32_atomic_distances.csv'

pdb_path_nah = 'data/custom_analyses_data/crystal_structures/7txz.pdb'
source_chain_nah = 'A'
target_chains_nah = ['F', 'E']
output_path_nah = 'results/distances/df_nAH_atomic_distances.csv'

pdb_path_m102 = 'data/custom_analyses_data/crystal_structures/6cmg.pdb'
source_chain_m102 = 'A'
target_chains_m102 = ['B', 'C']
output_path_m102 = 'results/distances/df_m102_atomic_distances.csv'

df_HENV26 = check_file(pdb_path_26, source_chain_26, target_chains_26, 'HENV-26',output_path_26)
df_HENV32 = check_file(pdb_path_32, source_chain_32, target_chains_32, 'HENV-32',output_path_32)
df_nah = check_file(pdb_path_nah, source_chain_nah, target_chains_nah, 'nAH1.3',output_path_nah)
# Fix naming so consistent heavy and light chain naming. Assign the result back:
# df['col'].replace(..., inplace=True) is chained assignment, which does not update the
# parent frame under pandas Copy-on-Write.
df_nah['chain'] = df_nah['chain'].replace({'E': 'H', 'F': 'L'})
df_m102 = check_file(pdb_path_m102, source_chain_m102, target_chains_m102, 'm102.4',output_path_m102)
df_m102['chain'] = df_m102['chain'].replace({'C': 'H', 'B': 'L'})
File already exists,loading from disk
File already exists,loading from disk
File already exists,loading from disk
File already exists,loading from disk
In [18]:
def find_close_mab_sites(df, name):
    """Print and return the distinct sites whose ``distance`` is at most 4, in order of first appearance."""
    close_sites = list(df.loc[df['distance'] <= 4, 'site'].unique())
    print(f'Close sites for mAb {name} are: {close_sites}')
    return close_sites

### First find RBP sites that are close to mAb residues
# "Close" means the minimum atom-atom distance to the antibody is <= 4 (angstroms per the
# structure cell). Only the four mAbs with structures loaded above are included.
nah_close = find_close_mab_sites(df_nah,'nAH1.3')
HENV26_close = find_close_mab_sites(df_HENV26,'HENV-26')
HENV32_close = find_close_mab_sites(df_HENV32,'HENV-32')
m102_close = find_close_mab_sites(df_m102,'m102.4')

### Now combine the close residues AND the top escape sites identified previously
# NOTE(review): this is plain list concatenation, so a site that is both a top escape site
# and structurally close appears twice. Use sorted(set(...)) if uniqueness matters downstream.
nah_combined_sites = sites_dict['nAH1.3'] + nah_close 
HENV26_combined_sites = sites_dict['HENV-26'] + HENV26_close 
HENV32_combined_sites = sites_dict['HENV-32'] + HENV32_close
m102_combined_sites = sites_dict['m102.4'] + m102_close 
Close sites for mAb nAH1.3 are: [172, 183, 184, 185, 186, 187, 188, 190, 191, 358, 449, 450, 451, 472, 515, 516, 517, 518, 570]
Close sites for mAb HENV-26 are: [389, 401, 403, 404, 458, 488, 489, 490, 491, 492, 494, 497, 501, 504, 505, 506, 528, 529, 530, 531, 532, 533, 555, 556, 557, 581, 586]
Close sites for mAb HENV-32 are: [196, 199, 200, 201, 202, 203, 205, 206, 207, 210, 254, 256, 258, 260, 262, 263, 264, 266, 553]
Close sites for mAb m102.4 are: [239, 240, 241, 242, 305, 458, 488, 489, 490, 504, 505, 506, 507, 530, 532, 533, 555, 557, 559, 579, 580, 581, 588]
In [19]:
def make_distance(df):
    """Reshape per-site distances into heatmap rows labeled ``mutant='distance'``.

    Returns a new frame with columns site, value (the distance), mutant='distance',
    wildtype='' and effect='escape_median'. The input frame is not modified.
    """
    distance_rows = (
        df[['site', 'distance']]
        .rename(columns={'distance': 'value'})
        .assign(mutant='distance', wildtype='', effect='escape_median')
    )
    return distance_rows
        
# Distance rows (one per residue of the source chain) for each antibody with a structure.
distance_nah_df = make_distance(df_nah)
distance_26_df = make_distance(df_HENV26)
distance_32_df = make_distance(df_HENV32)
distance_m102_df = make_distance(df_m102)

display(distance_m102_df)
site value mutant wildtype effect
0 176 35.044434 distance escape_median
1 177 31.866014 distance escape_median
2 178 27.842815 distance escape_median
3 179 28.777035 distance escape_median
4 180 28.332012 distance escape_median
... ... ... ... ... ...
423 599 30.802711 distance escape_median
424 600 28.920950 distance escape_median
425 601 27.248772 distance escape_median
426 602 29.020868 distance escape_median
427 603 27.945621 distance escape_median

428 rows × 5 columns

Prepare dataframes for heatmaps

In [20]:
def _make_ab_heatmap_df(ab, extra_frames=()):
    """Build the long-format heatmap table for one antibody.

    Steps:
    - Build every (site, mutant) pair for sites 71-602 and the 20 amino acids.
    - Left-join antibody ``ab``'s rows from the global ``combined_df``; pairs without data
      get NaN escape.
    - Melt escape_median into rows with effect='escape_median'.
    - Append the low-entry mutants from the global ``func_scores_E3_low_effect`` as
      effect='effect' rows, then any ``extra_frames``.
    - Tag every row with ``ab``.
    """
    sites = range(71, 603)  # range end is exclusive: sites 71-602
    amino_acids = ['A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y']
    grid = pd.DataFrame([{'site': site, 'mutant': aa} for site in sites for aa in amino_acids])
    # Boolean mask instead of a string-built query, so antibody names need no quoting.
    ab_rows = combined_df[combined_df['ab'] == ab]
    all_sites_df = pd.merge(grid, ab_rows, on=['site', 'mutant'], how='left')
    df_melted = all_sites_df.melt(id_vars=['site', 'mutant', 'wildtype'],
                                  value_vars=['escape_median'],
                                  var_name='effect', value_name='value')

    df_filtered = func_scores_E3_low_effect.melt(id_vars=['site', 'mutant', 'wildtype'],
                                                 value_vars=['effect'],
                                                 var_name='effect', value_name='value')

    df_test = pd.concat([df_melted, df_filtered, *extra_frames], ignore_index=True)
    df_test['ab'] = ab
    return df_test


def make_empty_df_with_distance(ab,distance_file):
    """Heatmap table for ``ab`` (escape + low-entry rows) with ``distance_file`` rows appended."""
    return _make_ab_heatmap_df(ab, (distance_file,))

empty_df_m102 = make_empty_df_with_distance("m102.4",distance_m102_df)
empty_df_HENV26 = make_empty_df_with_distance("HENV-26",distance_26_df)
empty_df_HENV32 = make_empty_df_with_distance("HENV-32",distance_32_df)
empty_df_nah = make_empty_df_with_distance("nAH1.3",distance_nah_df)

def make_empty_df(ab):
    """Heatmap table for ``ab`` (escape + low-entry rows) for antibodies without a structure."""
    return _make_ab_heatmap_df(ab)

empty_df_HENV117 = make_empty_df("HENV-117")
empty_df_HENV103 = make_empty_df("HENV-103")

# Stack the per-antibody heatmap tables. HENV-117 and HENV-103 have no distance rows.
combined_ab = pd.concat([empty_df_m102,empty_df_HENV26,empty_df_HENV32,empty_df_nah,empty_df_HENV117,empty_df_HENV103])
display(combined_ab)
site mutant wildtype effect value ab
0 71 A NaN escape_median NaN m102.4
1 71 C Q escape_median 0.294900 m102.4
2 71 D Q escape_median -0.064940 m102.4
3 71 E Q escape_median 0.020610 m102.4
4 71 F Q escape_median 0.006025 m102.4
... ... ... ... ... ... ...
13346 597 S I effect -2.796000 HENV-103
13347 597 T I effect -3.526000 HENV-103
13348 598 C P effect -2.300000 HENV-103
13349 598 W P effect -3.177000 HENV-103
13350 598 Y P effect -2.057000 HENV-103

81836 rows × 6 columns

In [21]:
def plot_distance_only(df, trigger):
    """Stack one site x amino-acid escape heatmap per antibody.

    Parameters
    ----------
    df : pandas.DataFrame
        Long-format table with columns ``site, mutant, wildtype, effect, value, ab``
        (e.g. ``combined_ab``).
        - Rows with ``mutant == 'distance'`` form the distance-to-mAb track,
          which is colored only for mAbs that have one.
        - Rows with ``effect == 'effect'`` are drawn as a grey layer.
        - ``effect == 'escape_median'`` rows are colored by escape.
    trigger : bool
        If truthy, each antibody panel is restricted to its own site subset
        (the ``*_combined_sites`` / ``sites_dict`` globals). Otherwise all sites
        71-602 are shown, with every 10th site labelled.

    Returns
    -------
    alt.VConcatChart
        One layered heatmap per antibody, stacked vertically.
    """
    # Heatmap row order: distance track first, then the amino acids.
    custom_order = ['distance', 'R', 'K', 'H', 'D', 'E', 'Q', 'N', 'S', 'T', 'Y',
                    'W', 'F', 'A', 'I', 'L', 'M', 'V', 'G', 'P', 'C']
    all_residues = range(71, 603)
    ab_list = ['m102.4', 'HENV-26', 'HENV-117', 'HENV-103', 'HENV-32', 'nAH1.3']
    # mAbs drawn with a colored distance track; the others get a transparent placeholder layer.
    distance_abs = ('m102.4', 'HENV-26', 'HENV-32', 'nAH1.3')
    strokewidth_size = 0.25

    # Row order matters: drop_duplicates() below keeps the FIRST row per
    # (site, wildtype). So keep the site-then-mutant-rank sorting.
    mutant_rank = {mutant: i for i, mutant in enumerate(custom_order)}
    final_df = df.sort_values('site')
    final_df = (
        final_df.assign(mutant_rank=final_df['mutant'].map(mutant_rank))
        .sort_values('mutant_rank')
        .drop(columns=['mutant_rank'])
    )
    sites = sorted(final_df['site'].unique(), key=float)

    # Site subsets used when `trigger` is set (globals defined earlier in the notebook).
    site_subsets = {
        'm102.4': m102_combined_sites,
        'HENV-26': HENV26_combined_sites,
        'HENV-32': HENV32_combined_sites,
        'HENV-103': sites_dict['HENV-103'],
        'HENV-117': sites_dict['HENV-117'],
        'nAH1.3': nah_combined_sites,
    }

    charts = []
    for idx, ab in enumerate(ab_list):
        tmp_df = final_df[final_df['ab'] == ab]
        if trigger:
            tmp_df = tmp_df[tmp_df['site'].isin(site_subsets[ab])]
            x_axis = alt.Axis(labelAngle=-90, title="Site")
        else:
            tmp_df = tmp_df[tmp_df['site'].isin(all_residues)]
            is_last_plot = idx == len(ab_list) - 1
            # Only multiples of 10 get a label text; the axis title appears on the
            # bottom panel only.
            x_axis = alt.Axis(labelAngle=-90,
                              labelExpr="datum.value % 10 === 0 ? datum.value : ''",
                              title="Site" if is_last_plot else None,
                              labels=True)

        # Escape color domain from this panel's escape rows only (no distance / grey rows).
        effect_df = tmp_df[(tmp_df['mutant'] != 'distance') & (tmp_df['effect'] != 'effect')]
        max_color = effect_df['value'].max()
        min_color = effect_df['value'].min()
        # Widen the negative end for mAbs with few sensitizing mutations.
        if min_color > -1:
            min_color = min_color - 1
        color_scale_escape = alt.Scale(scheme='redblue', domainMid=0, domain=[min_color, max_color])
        color_scale_distance = alt.Scale(scheme='purples', domain=[0, 15], reverse=True)

        unique_wildtypes_df = tmp_df.drop_duplicates(subset=['site', 'wildtype'])
        escape_rows = alt.datum.effect == 'escape_median'

        base = (
            alt.Chart(tmp_df, title=f'{ab}')
            .encode(
                x=alt.X('site:O', title='Site', sort=sites, axis=x_axis),
                # Explicit order list; an EncodingSortField would need a matching data column.
                y=alt.Y('mutant', title='Amino Acid', sort=custom_order, axis=alt.Axis(grid=False)),
                tooltip=['site', 'wildtype', 'mutant', 'value'],
            )
            .properties(width=alt.Step(10), height=alt.Step(11))
        )

        # Light-grey base cell under every escape_median row.
        chart_empty = base.mark_rect(color='#e6e7e8').transform_filter(escape_rows)

        # Escape values on a diverging scale centred at 0.
        chart_effect = base.mark_rect(stroke='black', strokeWidth=strokewidth_size).encode(
            color=alt.condition(
                'datum.mutant != "distance"',
                alt.Color('value:Q', scale=color_scale_escape, legend=alt.Legend(title=f'{ab} Escape')),
                alt.value('transparent'),
            ),
        ).transform_filter(escape_rows)

        # Distance-to-mAb track (colored only for mAbs that have distance rows).
        if ab in distance_abs:
            chart_distance = base.mark_rect().encode(
                color=alt.condition(
                    'datum.mutant == "distance"',
                    alt.Color('value:Q', scale=color_scale_distance, legend=alt.Legend(title='Distance to mAb')),
                    alt.value('transparent'),
                )
            ).transform_filter(escape_rows)
        else:
            chart_distance = base.mark_rect(color='transparent').transform_filter(escape_rows)

        # Grey cells for the rows tagged effect == 'effect'.
        chart_filtered = base.mark_rect(
            color='#939598', stroke='black', strokeWidth=strokewidth_size
        ).transform_filter(alt.datum.effect == 'effect')

        # Wildtype cells: a white box plus an 'X' at (site, wildtype).
        # These are Vega expressions, so `!= None` is intentional (renders as null checks).
        wildtype_filter = (alt.datum.wildtype != '') & (alt.datum.wildtype != None) & (alt.datum.value != None)  # noqa: E711
        wildtype_layer_box = (
            alt.Chart(unique_wildtypes_df)
            .mark_rect(color='white', stroke='black', strokeWidth=strokewidth_size)
            .encode(
                x=alt.X('site:O', sort=sites),
                y=alt.Y('wildtype', sort=custom_order),
                opacity=alt.value(1),
            )
            .transform_filter(wildtype_filter)
        )
        wildtype_layer = (
            alt.Chart(unique_wildtypes_df)
            .mark_text(color='black', text='X', size=8)
            .encode(
                x=alt.X('site:O', sort=sites),
                y=alt.Y('wildtype', sort=custom_order),
                opacity=alt.value(1),
            )
            .transform_filter(wildtype_filter)
        )

        chart = alt.layer(
            chart_empty, chart_effect, chart_distance, chart_filtered,
            wildtype_layer_box, wildtype_layer,
        ).resolve_scale(color='independent')
        charts.append(chart)

    return (
        alt.vconcat(*charts, spacing=1)
        .resolve_scale(y='shared', x='independent', color='independent')
        .configure_title(
            anchor='start',  # left-align panel titles
            offset=10,       # distance between title and chart
            orient='top',
        )
    )


# Per-antibody site subsets only.
mab_plot = plot_distance_only(combined_ab, trigger=True)
mab_plot.display()
mab_plot.save(mab_plot_top)

Make full antibody escape heatmaps¶

In [22]:
# Every site from 71 to 602.
mab_all = plot_distance_only(combined_ab, trigger=False)
mab_all.display()
mab_all.save(mab_plot_all)

Now make heatmaps of antibody escape versus ephrin-B2 binding¶

First prepare data:

In [23]:
# Per-site median of binding_median across all mutants measured at that site.
bind_df = pd.read_csv(binding_data)
binding_df = (
    bind_df
    .groupby('site')['binding_median']
    .median()
    .reset_index()
)

def make_empty_binding():
    """Return per-site median binding for every site 71-602 in heatmap long format.

    Sites absent from the global ``binding_df`` get a NaN ``value``. Every row is
    tagged ``effect='escape_median'``, so it passes the same ``effect`` filters as
    the escape rows, and ``ab='Ephrin-B2 binding'``, so it forms its own heatmap row.
    """
    all_sites = pd.DataFrame({'site': list(range(71, 603))})
    merged = all_sites.merge(binding_df, on='site', how='left')
    merged = merged.rename(columns={'binding_median': 'value'})
    merged['effect'] = 'escape_median'
    merged['ab'] = 'Ephrin-B2 binding'
    return merged

binding_empty = make_empty_binding()

# Median escape per (antibody, site) across all mutants at that site.
escape_df = (
    combined_df
    .groupby(['ab', 'site'])['escape_median']
    .median()
    .reset_index()
)
def make_empty_df(ab, sites=range(71, 603)):
    """Return one antibody's per-site median escape for every site in ``sites``.

    NOTE: this redefines (shadows) the site x mutant ``make_empty_df`` from an
    earlier cell. This version works at site level only.

    Parameters
    ----------
    ab : str
        Antibody name as it appears in the global ``escape_df['ab']``.
    sites : iterable of int, optional
        Sites to include. Defaults to 71-602 inclusive.

    Returns
    -------
    pandas.DataFrame
        Columns ``site, effect, value, ab``, with ``effect == 'escape_median'``.
        Sites without data have a NaN ``value``.
    """
    all_sites = pd.DataFrame({'site': list(sites)})
    # Boolean mask rather than an f-string `query`, which breaks if `ab` contains a quote.
    ab_escape = escape_df.loc[escape_df['ab'] == ab, ['site', 'escape_median']]
    merged = all_sites.merge(ab_escape, on='site', how='left')
    # Same columns/order the former melt + single-frame concat produced.
    return pd.DataFrame({
        'site': merged['site'],
        'effect': 'escape_median',
        'value': merged['escape_median'],
        'ab': ab,
    })

ab_list = ['m102.4', 'HENV-26', 'HENV-117', 'HENV-103', 'HENV-32', 'nAH1.3']

# Per-site escape for each antibody (fresh 0..n index), then the binding row
# appended with its own index.
empty = [make_empty_df(ab) for ab in ab_list]
all_empties_df = pd.concat(empty, ignore_index=True)
all_empties_df = pd.concat([all_empties_df, binding_empty])
display(all_empties_df)
site effect value ab
0 71 escape_median 0.006471 m102.4
1 72 escape_median 0.029320 m102.4
2 73 escape_median 0.045520 m102.4
3 74 escape_median -0.017160 m102.4
4 75 escape_median -0.003845 m102.4
... ... ... ... ...
527 598 escape_median 0.133200 Ephrin-B2 binding
528 599 escape_median -0.159500 Ephrin-B2 binding
529 600 escape_median -0.014710 Ephrin-B2 binding
530 601 escape_median 0.193800 Ephrin-B2 binding
531 602 escape_median -0.082220 Ephrin-B2 binding

3724 rows × 4 columns

In [24]:
def make_heatmap_with_binding(df, show_legends=False):
    """Plot per-site median mAb escape and ephrin-B2 binding as wrapped heatmap rows.

    Sites 71-602 are split into four panels of 133 sites, stacked vertically.
    Within each panel every ``ab`` value is one row, ordered by ``sort_order``.

    Parameters
    ----------
    df : pandas.DataFrame
        Long-format table with columns ``site, effect, value, ab`` (e.g.
        ``all_empties_df``).
        - ``effect == 'escape_median'`` rows are colored on a fixed [-1, 1] scale.
        - The ``'Ephrin-B2 binding'`` row is re-colored on a [-5, 2] scale.
        - ``'NiV Polymorphism'`` rows, if any, are drawn black.
    show_legends : bool, optional
        If True, attach the escape and binding color legends to the first panel.
        Defaults to False, which reproduces the previous output (no legends).

    Returns
    -------
    alt.VConcatChart
    """
    sort_order = ['NiV Polymorphism', 'Ephrin-B2 binding', 'm102.4', 'HENV-26',
                  'HENV-117', 'HENV-103', 'HENV-32', 'nAH1.3']
    # Wrap the 532 sites into four panels of 133 sites each.
    full_ranges = [list(range(start, end))
                   for start, end in [(71, 204), (204, 337), (337, 470), (470, 603)]]
    color_scale_effect = alt.Scale(scheme='redblue', domainMid=0, domain=[-1, 1])
    color_scale_binding = alt.Scale(scheme='redblue', domainMid=0, domain=[-5, 2])

    def _value_color(scale, with_legend):
        # legend=None hides the legend; omitting it keeps Altair's default legend.
        if with_legend:
            return alt.Color('value:Q', scale=scale)
        return alt.Color('value:Q', scale=scale, legend=None)

    charts = []
    for idx, subset in enumerate(full_ranges):
        subset_df = df[df['site'].isin(subset)]
        is_last_plot = idx == len(full_ranges) - 1
        with_legend = show_legends and idx == 0
        # Only multiples of 10 get a label text; the axis title appears on the
        # bottom panel only.
        x_axis = alt.Axis(labelAngle=-90,
                          labelExpr="datum.value % 10 === 0 ? datum.value : ''",
                          title="Site" if is_last_plot else None,
                          labels=True)

        base = alt.Chart(subset_df).encode(
            x=alt.X('site:O', title='Site', axis=x_axis),
            y=alt.Y('ab', title=None, sort=sort_order, axis=alt.Axis(grid=False)),
            tooltip=['site', 'value'],
        ).properties(
            width=alt.Step(10),
            height=alt.Step(11)
        )

        # Light-grey base cell under every escape_median row.
        chart_empty = base.mark_rect(color='#e6e7e8').transform_filter(
            alt.datum.effect == 'escape_median'
        )
        # Escape (and, underneath the binding layer, binding) values on the effect scale.
        chart_effect = base.mark_rect(stroke='black', strokeWidth=0.25).encode(
            color=alt.condition(
                'datum.effect == "escape_median"',
                _value_color(color_scale_effect, with_legend),
                alt.value('transparent')
            )
        ).transform_filter(
            alt.datum.effect == 'escape_median'
        )
        # Binding row re-drawn on its own scale on top of the effect coloring.
        # NOTE(review): alt.value('value') sets the stroke to the literal string
        # 'value', which is not a CSS color. Kept as-is to preserve the rendered
        # output; confirm which stroke was intended.
        chart_binding = base.mark_rect(strokeWidth=1.1).encode(
            stroke=alt.value('value'),
            color=alt.condition(
                'datum.effect == "escape_median"',
                _value_color(color_scale_binding, with_legend),
                alt.value('transparent')
            )
        ).transform_filter(
            alt.datum.ab == 'Ephrin-B2 binding'
        )
        chart_poly = base.mark_rect(color='black').encode().transform_filter(
            alt.datum.ab == 'NiV Polymorphism'
        )

        chart = alt.layer(chart_empty, chart_effect, chart_binding, chart_poly).resolve_scale(color='independent')
        charts.append(chart)

    return alt.vconcat(
        *charts, spacing=5,
        title='Heatmap of median mAb escape and Ephrin-B2 binding',
    ).resolve_scale(y='shared', x='independent', color='shared')

# Render the escape + binding overview from the frame built in the previous cell.
chart = make_heatmap_with_binding(df=all_empties_df)
chart.display()
chart.save(aggregate_mab_and_binding)

Now show the escape heatmap alongside Nipah virus (NiV) polymorphisms¶

In [25]:
def make_contact(sites=None):
    """Return one heatmap row that marks NiV polymorphic sites.

    Parameters
    ----------
    sites : iterable of int, optional
        Sites to mark. Defaults to the module-level ``niv_poly`` list, which is
        looked up at call time (it is assigned after this definition).

    Returns
    -------
    pandas.DataFrame
        Columns ``site, value, ab, effect``, where ``value == 0.0``,
        ``ab == 'NiV Polymorphism'`` and ``effect == 'median_escape'``.
    """
    if sites is None:
        sites = niv_poly
    sites = list(sites)
    # NOTE(review): 'median_escape' differs from the 'escape_median' tag on the
    # escape rows. These rows therefore skip the `effect == 'escape_median'` layers
    # and are drawn only by the black 'NiV Polymorphism' layer. Kept as-is;
    # confirm this is intended.
    return pd.DataFrame({
        'site': sites,
        'value': [0.0] * len(sites),
        'ab': 'NiV Polymorphism',
        'effect': 'median_escape',
    })

# Sites marked in the 'NiV Polymorphism' heatmap row.
niv_poly = [82, 89, 135, 172, 228, 236, 274, 288, 299, 325, 328, 329, 335, 339, 344,
            376, 384, 385, 386, 421, 423, 424, 426, 427, 470, 478, 481, 498, 502, 545]
contact_df = make_contact()

# Read the papermill `binding_data` parameter instead of a hard-coded path, so a
# parameterized run uses the same file as the earlier binding cell. `bind_df` also
# feeds the escape-vs-binding comparison further down.
bind_df = pd.read_csv(binding_data)
binding_df = bind_df.groupby('site')['binding_median'].median().reset_index()

def make_empty_binding():
    """Per-site median binding over sites 71-602, shaped like the escape heatmap rows.

    Identical redefinition of the helper from the earlier binding cell. Sites
    missing from the global ``binding_df`` get NaN. Every row is tagged
    ``effect='escape_median'`` and ``ab='Ephrin-B2 binding'``.
    """
    site_frame = pd.DataFrame({'site': range(71, 603)})
    return (
        pd.merge(site_frame, binding_df, on='site', how='left')
        .rename(columns={'binding_median': 'value'})
        .assign(effect='escape_median', ab='Ephrin-B2 binding')
    )

# Recomputed here; binding_empty is not used by the polymorphism heatmap in this cell.
binding_empty = make_empty_binding()

# Median escape per (antibody, site) across all mutants at that site.
escape_df = (
    combined_df
    .groupby(['ab', 'site'])['escape_median']
    .median()
    .reset_index()
)
def make_empty_df(ab, sites=range(71, 603)):
    """Return one antibody's per-site median escape for every site in ``sites``.

    NOTE: this redefines the helper of the same name from earlier cells. This
    version works at site level only.

    Parameters
    ----------
    ab : str
        Antibody name as it appears in the global ``escape_df['ab']``.
    sites : iterable of int, optional
        Sites to include. Defaults to 71-602 inclusive.

    Returns
    -------
    pandas.DataFrame
        Columns ``site, effect, value, ab``, with ``effect == 'escape_median'``.
        Sites without data have a NaN ``value``.
    """
    all_sites = pd.DataFrame({'site': list(sites)})
    # Boolean mask rather than an f-string `query`, which breaks if `ab` contains a quote.
    ab_escape = escape_df.loc[escape_df['ab'] == ab, ['site', 'escape_median']]
    merged = all_sites.merge(ab_escape, on='site', how='left')
    # Same columns/order the former melt + single-frame concat produced.
    return pd.DataFrame({
        'site': merged['site'],
        'effect': 'escape_median',
        'value': merged['escape_median'],
        'ab': ab,
    })

ab_list = ['m102.4', 'HENV-26', 'HENV-117', 'HENV-103', 'HENV-32', 'nAH1.3']

# Per-site escape for each antibody (fresh 0..n index), then the NiV polymorphism
# row appended with its own index.
empty = [make_empty_df(ab) for ab in ab_list]
all_empties_df = pd.concat(empty, ignore_index=True)
all_empties_df = pd.concat([all_empties_df, contact_df])
display(all_empties_df)
site effect value ab
0 71 escape_median 0.006471 m102.4
1 72 escape_median 0.029320 m102.4
2 73 escape_median 0.045520 m102.4
3 74 escape_median -0.017160 m102.4
4 75 escape_median -0.003845 m102.4
... ... ... ... ...
25 478 median_escape 0.000000 NiV Polymorphism
26 481 median_escape 0.000000 NiV Polymorphism
27 498 median_escape 0.000000 NiV Polymorphism
28 502 median_escape 0.000000 NiV Polymorphism
29 545 median_escape 0.000000 NiV Polymorphism

3222 rows × 4 columns

In [26]:
def make_heatmap_with_polymorphisms(df, show_legends=False):
    """Plot per-site median mAb escape with a row marking NiV polymorphisms.

    Sites 71-602 are split into four panels of 133 sites, stacked vertically.
    Within each panel every ``ab`` value is one row, ordered by ``sort_order``.

    Parameters
    ----------
    df : pandas.DataFrame
        Long-format table with columns ``site, effect, value, ab`` (e.g.
        ``all_empties_df``).
        - ``effect == 'escape_median'`` rows are colored on a fixed [-1, 1] scale.
        - ``'NiV Polymorphism'`` rows are drawn as black cells.
    show_legends : bool, optional
        If True, attach the escape color legend to the first panel. Defaults to
        False, which reproduces the previous output (no legend).

    Returns
    -------
    alt.VConcatChart
    """
    sort_order = ['NiV Polymorphism', 'm102.4', 'HENV-26', 'HENV-117', 'HENV-103', 'HENV-32', 'nAH1.3']
    # Wrap the 532 sites into four panels of 133 sites each.
    full_ranges = [list(range(start, end))
                   for start, end in [(71, 204), (204, 337), (337, 470), (470, 603)]]
    color_scale_effect = alt.Scale(scheme='redblue', domainMid=0, domain=[-1, 1])

    charts = []
    for idx, subset in enumerate(full_ranges):
        subset_df = df[df['site'].isin(subset)]
        is_last_plot = idx == len(full_ranges) - 1
        with_legend = show_legends and idx == 0
        # Only multiples of 10 get a label text; the axis title appears on the
        # bottom panel only.
        x_axis = alt.Axis(labelAngle=-90,
                          labelExpr="datum.value % 10 === 0 ? datum.value : ''",
                          title="Site" if is_last_plot else None,
                          labels=True)

        base = alt.Chart(subset_df).encode(
            x=alt.X('site:O', title='Site', axis=x_axis),
            y=alt.Y('ab', title=None, sort=sort_order, axis=alt.Axis(grid=False)),
            tooltip=['site', alt.Tooltip('value', format=".2f")],
        ).properties(
            width=alt.Step(10),
            height=alt.Step(11)
        )

        # Light-grey base cell under every escape_median row.
        chart_empty = base.mark_rect(color='#e6e7e8').transform_filter(
            alt.datum.effect == 'escape_median'
        )
        # legend=None hides the legend; omitting it keeps Altair's default legend.
        if with_legend:
            effect_color = alt.Color('value:Q', scale=color_scale_effect)
        else:
            effect_color = alt.Color('value:Q', scale=color_scale_effect, legend=None)
        chart_effect = base.mark_rect(stroke='black', strokeWidth=0.25).encode(
            color=alt.condition(
                'datum.effect == "escape_median"',
                effect_color,
                alt.value('transparent')
            )
        ).transform_filter(
            alt.datum.effect == 'escape_median'
        )
        chart_poly = base.mark_rect(color='black').encode().transform_filter(
            alt.datum.ab == 'NiV Polymorphism'
        )

        chart = alt.layer(chart_empty, chart_effect, chart_poly).resolve_scale(color='independent')
        charts.append(chart)

    return alt.vconcat(
        *charts, spacing=5,
        title='Heatmap of median mAb escape and Nipah Polymorphisms',
    ).resolve_scale(y='shared', x='independent', color='shared')

# Render the escape + NiV polymorphism overview from the frame built in the previous cell.
chart = make_heatmap_with_polymorphisms(df=all_empties_df)
chart.display()
chart.save(aggregate_mab_and_niv_polymorphism)

Make plots comparing escape with ephrin-B2 binding, to see whether escape mutations act by increasing binding¶

In [27]:
# Attach each mutation's binding_median to its escape rows (NaN where binding is missing).
binding_cols = bind_df[['site', 'wildtype', 'mutant', 'binding_median']]
new_merged_df = combined_df.merge(
    binding_cols,
    on=['site', 'wildtype', 'mutant'],
    how='left',
)
ab_list = ['m102.4', 'HENV-26', 'HENV-117', 'HENV-103', 'HENV-32', 'nAH1.3']
# Panel layout of the escape-vs-binding grid: one list per row.
ab_list1 = ['m102.4', 'HENV-26', 'HENV-117']
ab_list2 = ['HENV-103', 'HENV-32']
ab_list3 = ['nAH1.3']
def plot_escape_vs_binding(df, rows=None):
    """Scatter per-mutation mAb escape against ephrin-B2 binding, one panel per mAb.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``ab, site, wildtype, mutant, escape_median, binding_median``
        (e.g. ``new_merged_df``).
    rows : sequence of sequences of str, optional
        Antibody names for each row of the grid. Defaults to the module-level
        ``(ab_list1, ab_list2, ab_list3)``.
        - Rows with several antibodies share their x/y scales.
        - A single-antibody row is placed as a plain chart.
        - Empty rows are skipped.

    Returns
    -------
    alt.VConcatChart
    """
    if rows is None:
        rows = (ab_list1, ab_list2, ab_list3)

    def _panel(ab):
        # One translucent black point per mutation of this antibody.
        # (A color `scheme` on the y position scale was dropped: Vega-Lite
        # ignores scheme on non-color channels.)
        return alt.Chart(df[df['ab'] == ab], title=f'{ab}').mark_point(
            filled=True, size=10, color='black', opacity=0.15
        ).encode(
            alt.X('binding_median', title='Binding', axis=alt.Axis(grid=True, tickCount=3)),
            alt.Y('escape_median', title='Ab Escape', axis=alt.Axis(grid=True, tickCount=3)),
            tooltip=['site', 'wildtype', 'mutant', 'escape_median', 'binding_median'],
        ).properties(
            width=200,
            height=200
        )

    row_charts = []
    for row in rows:
        panels = [_panel(ab) for ab in row]
        if not panels:
            continue
        if len(panels) == 1:
            row_charts.append(panels[0])
        else:
            row_charts.append(
                alt.hconcat(*panels, spacing=5).resolve_scale(x='shared', y='shared')
            )

    return alt.vconcat(*row_charts).configure_title(anchor='middle', fontSize=16)

# Build, show and save the escape-vs-binding scatter grid.
tmp_img_test = plot_escape_vs_binding(df=new_merged_df)
tmp_img_test.display()
tmp_img_test.save(binding_vs_escape)
In [ ]: